import argparse
import builtins
from contextlib import contextmanager
from datetime import datetime
from typing import Iterator
from typing import TextIO

import yaml
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander
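# Answer-quality regression harness: runs each question from a YAML file through the
# direct QA flow and writes the answer, plus optional retrieval / reranking / LLM token
# metrics, to a results file. See the argparse setup under __main__ for the available
# flags and their default paths.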


engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Iterator[None]:
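    """Temporarily monkeypatch builtins.print so that everything printed inside the
    `with` block goes to `file` instead of stdout; the original print is restored on
    exit even if an exception is raised."""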
    original_print = builtins.print
    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
    try:
        yield
    finally:
        builtins.print = original_print


def load_yaml(filepath: str) -> dict:
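    """Parse the YAML file at `filepath` and return its contents as a dict."""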
    with open(filepath, "r") as file:
        data = yaml.safe_load(file)
    return data


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
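    """Greedily wrap `s` into lines of at most `max_line_size` characters (a single
    word longer than the limit gets its own line), optionally prepending a tab to
    each line."""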
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0

            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


def get_answer_for_question(
    query: str, db_session: Session, real_time_flow: bool = False
) -> tuple[
    str | None,
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
    LLMMetricsContainer | None,
]:
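    """Run a single query through the QA pipeline and return the answer text together
    with the retrieval, reranking, and LLM metrics captured by the callback handlers
    (each may be None if no metrics were recorded)."""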
    question = QuestionRequest(
        query=query,
        collection="danswer_index",
        use_keyword=False,
        filters=None,
        offset=None,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()
    llm_metrics = MetricsHander[LLMMetricsContainer]()

    answer = answer_qa_query(
        question=question,
        user=None,
        db_session=db_session,
        answer_generation_timeout=100,
        real_time_flow=real_time_flow,
        enable_reflexion=False,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
        llm_metrics_callback=llm_metrics.record_metric,
    )

    return (
        answer.answer,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
        llm_metrics.metrics,
    )


def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
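    """Print per-chunk retrieval info (document, link, section start, score); only the
    first 10 chunks are shown unless show_all is True."""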
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
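    """Print per-chunk reranking info; only the first 10 chunks are shown unless
    show_all is True."""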
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
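    """Print the prompt and response token counts from the LLM call."""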
    print(f"\tPrompt Tokens: {metrics_container.prompt_tokens}")
    print(f"\tResponse Tokens: {metrics_container.response_tokens}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_yaml",
        type=str,
        help="Path to the Questions YAML file.",
        default="./tests/regression/answer_quality/sample_questions.yaml",
        nargs="?",
    )
    parser.add_argument(
        "--real-time", action="store_true", help="Set to use the real-time flow."
    )
    parser.add_argument(
        "--discard-metrics",
        action="store_true",
        help="Set to not include metrics on search, rerank, and token counts.",
    )
    parser.add_argument(
        "--all-results",
        action="store_true",
        help="Set to not include more than the 10 top sections for search and reranking metrics.",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/answer_quality/regression_results.txt",
    )
    args = parser.parse_args()

    questions_data = load_yaml(args.regression_yaml)
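    # Expected YAML layout (illustrative; field names inferred from how `sample` is
    # used in the loop below):
    # questions:
    #   - id: 1
    #     question: "..."
    #     expected_answer: "..."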

    with open(args.output, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Question Answering Flow")
            print(
                "Note that running metrics requires tokenizing all "
                "prompts/returns and slightly slows down inference."
            )
            print(
                "Also note that the text embedding model (bi-encoder) currently used is trained for "
                "relative distances, not absolute distances. Therefore cosine similarity values may all be > 0.5 "
                "even for poor matches"
            )

            with Session(engine, expire_on_commit=False) as db_session:
                for sample in questions_data["questions"]:
                    print(
                        f"Running Test for Question {sample['id']}: {sample['question']}"
                    )

                    start_time = datetime.now()
                    (
                        answer,
                        retrieval_metrics,
                        rerank_metrics,
                        llm_metrics,
                    ) = get_answer_for_question(
                        sample["question"], db_session, real_time_flow=args.real_time
                    )
                    end_time = datetime.now()

                    print(f"====Duration: {end_time - start_time}====")
                    print(f"Question {sample['id']}:")
                    print(f'\t{sample["question"]}')
                    print("\nApproximate Expected Answer:")
                    print(f'\t{sample["expected_answer"]}')
                    print("\nActual Answer:")
                    print(
                        word_wrap(answer)
                        if answer
                        else "\tFailed, either crashed or refused to answer."
                    )
                    if not args.discard_metrics:
                        print("\nLLM Tokens Usage:")
                        if llm_metrics is None:
                            print("No LLM Metrics Available")
                        else:
                            _print_llm_metrics(llm_metrics)

                        print("\nRetrieval Metrics:")
                        if retrieval_metrics is None:
                            print("No Retrieval Metrics Available")
                        else:
                            _print_retrieval_metrics(
                                retrieval_metrics, show_all=args.all_results
                            )

                        print("\nReranking Metrics:")
                        if rerank_metrics is None:
                            print("No Reranking Metrics Available")
                        else:
                            _print_reranking_metrics(
                                rerank_metrics, show_all=args.all_results
                            )

                    print("\n\n", flush=True)
